{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load schedulers and models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A diffusion pipeline is a collection of interchangeable schedulers and models that can be mixed and matched to tailor the pipeline to a specific use case. The scheduler encapsulates the entire denoising process, such as the number of denoising steps and the algorithm for finding the denoised sample. Schedulers are not parameterized or trained, so they don't take much memory. The model is usually only concerned with the forward pass of going from a noisy input to a less noisy sample.\n",
    "\n",
    "This guide will show you how to load schedulers and models to customize a pipeline. You'll use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint throughout this guide, so let's load it first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers import DiffusionPipeline\n",
    "\n",
    "pipeline = DiffusionPipeline.from_pretrained(\n",
    "    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, use_safetensors=True\n",
    ").to(\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see what scheduler this pipeline uses with the `pipeline.scheduler` attribute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline.scheduler\n",
    "# Example output:\n",
    "# PNDMScheduler {\n",
    "#   \"_class_name\": \"PNDMScheduler\",\n",
    "#   \"_diffusers_version\": \"0.21.4\",\n",
    "#   \"beta_end\": 0.012,\n",
    "#   \"beta_schedule\": \"scaled_linear\",\n",
    "#   \"beta_start\": 0.00085,\n",
    "#   \"clip_sample\": false,\n",
    "#   \"num_train_timesteps\": 1000,\n",
    "#   \"set_alpha_to_one\": false,\n",
    "#   \"skip_prk_steps\": true,\n",
    "#   \"steps_offset\": 1,\n",
    "#   \"timestep_spacing\": \"leading\",\n",
    "#   \"trained_betas\": null\n",
    "# }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a scheduler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A scheduler is defined by a configuration file, and the same configuration file can be used by a variety of schedulers. Load a scheduler with the [SchedulerMixin.from_pretrained()](https://huggingface.co/docs/diffusers/main/en/api/schedulers/overview#diffusers.SchedulerMixin.from_pretrained) method, and specify the `subfolder` parameter to load the configuration file from the correct subfolder of the pipeline repository.\n",
    "\n",
    "For example, to load the [DDIMScheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/ddim#diffusers.DDIMScheduler):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import DDIMScheduler, DiffusionPipeline\n",
    "\n",
    "ddim = DDIMScheduler.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", subfolder=\"scheduler\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then you can pass the newly loaded scheduler to the pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline = DiffusionPipeline.from_pretrained(\n",
    "    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", scheduler=ddim, torch_dtype=torch.float16, use_safetensors=True\n",
    ").to(\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare schedulers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Schedulers have their own unique strengths and weaknesses, making it difficult to quantitatively compare which scheduler works best for a pipeline. You typically have to make a trade-off between denoising speed and denoising quality. We recommend trying out different schedulers to find one that works best for your use case. Check the `pipeline.scheduler.compatibles` attribute to see which schedulers are compatible with a pipeline.\n",
    "\n",
    "Let's compare the [LMSDiscreteScheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/lms_discrete#diffusers.LMSDiscreteScheduler), [EulerDiscreteScheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/euler#diffusers.EulerDiscreteScheduler), [EulerAncestralDiscreteScheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/euler_ancestral#diffusers.EulerAncestralDiscreteScheduler), and the [DPMSolverMultistepScheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/multistep_dpm_solver#diffusers.DPMSolverMultistepScheduler) on the following prompt and seed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers import DiffusionPipeline\n",
    "\n",
    "pipeline = DiffusionPipeline.from_pretrained(\n",
    "    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, use_safetensors=True\n",
    ").to(\"cuda\")\n",
    "\n",
    "prompt = \"A photograph of an astronaut riding a horse on Mars, high resolution, high definition.\"\n",
    "generator = torch.Generator(device=\"cuda\").manual_seed(8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To change the pipeline's scheduler, use the [from_config()](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method to load the current `pipeline.scheduler.config` into a different scheduler class, then assign the result to `pipeline.scheduler`.\n",
    "\n",
    "<hfoptions id=\"schedulers\">\n",
    "<hfoption id=\"LMSDiscreteScheduler\">\n",
    "\n",
    "[LMSDiscreteScheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/lms_discrete#diffusers.LMSDiscreteScheduler) typically generates higher quality images than the default scheduler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import LMSDiscreteScheduler\n",
    "\n",
    "pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)\n",
    "image = pipeline(prompt, generator=generator).images[0]\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "</hfoption>\n",
    "<hfoption id=\"EulerDiscreteScheduler\">\n",
    "\n",
    "[EulerDiscreteScheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/euler#diffusers.EulerDiscreteScheduler) can generate high-quality images in just 30 steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import EulerDiscreteScheduler\n",
    "\n",
    "pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)\n",
    "# Re-seed so this scheduler starts from the same initial noise as the others\n",
    "generator = torch.Generator(device=\"cuda\").manual_seed(8)\n",
    "image = pipeline(prompt, generator=generator).images[0]\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "</hfoption>\n",
    "<hfoption id=\"EulerAncestralDiscreteScheduler\">\n",
    "\n",
    "[EulerAncestralDiscreteScheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/euler_ancestral#diffusers.EulerAncestralDiscreteScheduler) can generate high-quality images in just 30 steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import EulerAncestralDiscreteScheduler\n",
    "\n",
    "pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)\n",
    "# Re-seed so this scheduler starts from the same initial noise as the others\n",
    "generator = torch.Generator(device=\"cuda\").manual_seed(8)\n",
    "image = pipeline(prompt, generator=generator).images[0]\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "</hfoption>\n",
    "<hfoption id=\"DPMSolverMultistepScheduler\">\n",
    "\n",
    "[DPMSolverMultistepScheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/multistep_dpm_solver#diffusers.DPMSolverMultistepScheduler) provides a balance between speed and quality and can generate high-quality images in just 20 steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import DPMSolverMultistepScheduler\n",
    "\n",
    "pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n",
    "# Re-seed so this scheduler starts from the same initial noise as the others\n",
    "generator = torch.Generator(device=\"cuda\").manual_seed(8)\n",
    "image = pipeline(prompt, generator=generator).images[0]\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "</hfoption>\n",
    "</hfoptions>\n",
    "\n",
    "<div class=\"flex gap-4\">\n",
    "  <div>\n",
    "    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png\" />\n",
    "    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">LMSDiscreteScheduler</figcaption>\n",
    "  </div>\n",
    "  <div>\n",
    "    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png\" />\n",
    "    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">EulerDiscreteScheduler</figcaption>\n",
    "  </div>\n",
    "</div>\n",
    "<div class=\"flex gap-4\">\n",
    "  <div>\n",
    "    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png\" />\n",
    "    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">EulerAncestralDiscreteScheduler</figcaption>\n",
    "  </div>\n",
    "  <div>\n",
    "    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png\" />\n",
    "    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">DPMSolverMultistepScheduler</figcaption>\n",
    "  </div>\n",
    "</div>\n",
    "\n",
    "Most images look very similar and are comparable in quality. Again, it often comes down to your specific use case so a good approach is to run multiple different schedulers and compare the results."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Flax schedulers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To compare Flax schedulers, you need to additionally load the scheduler state into the model parameters. For example, let's change the default scheduler in [FlaxStableDiffusionPipeline](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/text2img#diffusers.FlaxStableDiffusionPipeline) to use the super fast `FlaxDPMSolverMultistepScheduler`.\n",
    "\n",
    "> [!WARNING]\n",
    "> The `FlaxLMSDiscreteScheduler` and `FlaxDDPMScheduler` are not compatible with the [FlaxStableDiffusionPipeline](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/text2img#diffusers.FlaxStableDiffusionPipeline) yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import numpy as np\n",
    "from flax.jax_utils import replicate\n",
    "from flax.training.common_utils import shard\n",
    "from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler\n",
    "\n",
    "scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(\n",
    "    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n",
    "    subfolder=\"scheduler\"\n",
    ")\n",
    "pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(\n",
    "    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n",
    "    scheduler=scheduler,\n",
    "    variant=\"bf16\",\n",
    "    dtype=jax.numpy.bfloat16,\n",
    ")\n",
    "params[\"scheduler\"] = scheduler_state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then you can take advantage of Flax's compatibility with TPUs to generate a number of images in parallel. You'll need to make a copy of the model parameters for each available device and then split the inputs across them to generate your desired number of images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)\n",
    "prompt = \"A photograph of an astronaut riding a horse on Mars, high resolution, high definition.\"\n",
    "num_samples = jax.device_count()\n",
    "prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)\n",
    "\n",
    "prng_seed = jax.random.PRNGKey(0)\n",
    "num_inference_steps = 25\n",
    "\n",
    "# replicate the params on every device, then shard the rng and inputs across devices\n",
    "params = replicate(params)\n",
    "prng_seed = jax.random.split(prng_seed, jax.device_count())\n",
    "prompt_ids = shard(prompt_ids)\n",
    "\n",
    "images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images\n",
    "images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Models are loaded with the [ModelMixin.from_pretrained()](https://huggingface.co/docs/diffusers/main/en/api/models/overview#diffusers.ModelMixin.from_pretrained) method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [from_pretrained()](https://huggingface.co/docs/diffusers/main/en/api/models/overview#diffusers.ModelMixin.from_pretrained) reuses the cached files instead of re-downloading them.\n",
    "\n",
    "Models can be loaded from a subfolder with the `subfolder` argument. For example, the UNet weights for [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) are stored in the [unet](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet) subfolder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import UNet2DConditionModel\n",
    "\n",
    "unet = UNet2DConditionModel.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", subfolder=\"unet\", use_safetensors=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Models can also be loaded directly from a [repository](https://huggingface.co/google/ddpm-cifar10-32/tree/main) when the weights are stored at its root rather than in a subfolder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import UNet2DModel\n",
    "\n",
    "unet = UNet2DModel.from_pretrained(\"google/ddpm-cifar10-32\", use_safetensors=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To load and save model variants, specify the `variant` argument in [ModelMixin.from_pretrained()](https://huggingface.co/docs/diffusers/main/en/api/models/overview#diffusers.ModelMixin.from_pretrained) and [ModelMixin.save_pretrained()](https://huggingface.co/docs/diffusers/main/en/api/models/overview#diffusers.ModelMixin.save_pretrained)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import UNet2DConditionModel\n",
    "\n",
    "unet = UNet2DConditionModel.from_pretrained(\n",
    "    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", subfolder=\"unet\", variant=\"non_ema\", use_safetensors=True\n",
    ")\n",
    "unet.save_pretrained(\"./local-unet\", variant=\"non_ema\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the `torch_dtype` argument in [from_pretrained()](https://huggingface.co/docs/diffusers/main/en/api/models/overview#diffusers.ModelMixin.from_pretrained) to specify the dtype to load a model in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers import AutoModel\n",
    "\n",
    "unet = AutoModel.from_pretrained(\n",
    "    \"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"unet\", torch_dtype=torch.float16\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert a model to the specified dtype on the fly. This converts *all* weights, unlike the `torch_dtype` argument, which respects `_keep_in_fp32_modules`. The distinction matters for models whose layers must remain in fp32 for numerical stability and best generation quality (see this [example](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374))."
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
